#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import defaultdict
import logging
import os
from abc import ABC, abstractmethod
import re
import shutil
from typing import Dict, List, Optional, Tuple, Type

import fire
from tqdm import tqdm
import pandas as pd
from matplotlib import pyplot as plt

# Emit INFO-level records from the module-level logging calls below.
logging.basicConfig(level=logging.INFO)
# set default font used by matplotlib; the serif list is tried in order
# until an installed face is found
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Tex Gyre Termes', 'Times New Roman', 'SimSun', 'serif']
})


class BenchAnalyserABC(ABC):
    '''Abstract interface for analysers of one on-disk benchmark dataset layout.

    Concrete subclasses recognize a directory structure (``is_my_dataset``),
    load its files (``load``) and render plots (``visualize``).
    '''

    def __init__(self, basedir: os.PathLike):
        '''
        Args:
            basedir: Path to directory containing dataset.
        '''
        self._basedir = basedir

    @staticmethod
    @abstractmethod
    def is_my_dataset(basedir: os.PathLike) -> bool:
        '''Test if this analyser applies to this directory structure.

        Args:
            basedir: Path to directory containing dataset.
        '''
        ...

    @abstractmethod
    def load(self) -> None:
        '''Loads collected data from files.
        '''
        ...

    @abstractmethod
    def visualize(self) -> None:
        '''Generate and output visualizations of loaded dataset.
        '''
        ...

# Global registry of analyser classes; main() probes them in registration
# order and uses the first whose is_my_dataset() matches.
__g_all_analysers: List[Type[BenchAnalyserABC]] = []
def register_analyser(subcls: Type[BenchAnalyserABC]) -> None:
    # Record an analyser class so it participates in dataset detection.
    __g_all_analysers.append(subcls)


class LatencyBenchAnalyser(BenchAnalyserABC):
    '''Analyser for single-node latency datasets.

    Expected layout::

        <basedir>/<N>thrd/thrd<tid>_<iotype>.csv

    where ``N`` is the thread count, ``tid`` ranges over ``0..N-1`` and
    ``iotype`` is ``read`` or ``write``.
    '''

    def __init__(self, basedir: os.PathLike):
        '''
        Args:
            basedir: Path to directory containing dataset.
        '''
        super().__init__(basedir)

        def latency_df_factory() -> pd.DataFrame:
            # Empty frame with the expected schema.  astype() returns a NEW
            # frame, so its result must be kept (the original discarded it,
            # leaving the column with dtype ``object``).
            column_types = {
                'latency(us)': float,
            }
            df = pd.DataFrame(columns=list(column_types.keys()))
            return df.astype(column_types)

        # iotype -> thread_num -> concatenated per-operation latencies
        self.latency_metrics: Dict[str, Dict[int, pd.DataFrame]] = {
            'read': defaultdict(latency_df_factory),
            'write': defaultdict(latency_df_factory)
        }

    # Matches per-thread-count directory names.  Groups:
    #   thrd: Number of threads.
    __thrd_dir_regex = re.compile(r'(?P<thrd>\d+)thrd')
    # Documents the expected per-thread file name format.  Groups:
    #   thrdid: Thread ID.
    #   iotype: `read` or `write`.
    __lat_file_regex = re.compile(
        r'thrd(?P<thrdid>\d+)_'
        r'(?P<iotype>read|write)\.csv'
    )

    @staticmethod
    def is_my_dataset(basedir: os.PathLike) -> bool:
        '''Test if this analyser applies to this directory structure.

        Note: an empty ``basedir`` vacuously matches.

        Args:
            basedir: Path to directory containing dataset.
        '''
        Cls = LatencyBenchAnalyser
        for thrddir in os.listdir(basedir):
            thrdpath = os.path.join(basedir, thrddir)
            if not os.path.isdir(thrdpath):
                return False
            m = Cls.__thrd_dir_regex.match(thrddir)
            if not m:
                return False
            thrd_num = int(m.group('thrd'))
            # Each of the thrd_num threads must have exactly one read and
            # one write file -- no extras, no gaps.
            expected = {
                f'thrd{t}_{iotype}.csv'
                for t in range(thrd_num)
                for iotype in ('read', 'write')
            }
            if set(os.listdir(thrdpath)) != expected:
                return False
        return True

    def load(self) -> None:
        '''Loads collected data from files into ``self.latency_metrics``.'''
        for thrddir in os.listdir(self._basedir):
            thrdpath = os.path.join(self._basedir, thrddir)
            m = self.__thrd_dir_regex.match(thrddir)
            if m is None:
                # is_my_dataset() guarantees a match; fail loudly instead of
                # silencing the type checker if the directory changed since.
                raise RuntimeError(f'Unexpected directory name {thrddir!r}')
            thrd_num = int(m.group('thrd'))
            for tid in range(thrd_num):
                for iotype in ['read', 'write']:
                    fpath = os.path.join(thrdpath, f'thrd{tid}_{iotype}.csv')
                    self.latency_metrics[iotype][thrd_num] = pd.concat(
                        [self.latency_metrics[iotype][thrd_num], pd.read_csv(fpath)],
                        ignore_index=True)

    def visualize(self) -> None:
        '''Plot one latency/percentile line per thread count, per iotype.'''
        NUM_BINS = 1000
        percentiles = [i / NUM_BINS for i in range(NUM_BINS)] + [1.,]
        # iotype -> "<N> thrd" label -> latency at each percentile
        data: Dict[str, Dict[str, List[float]]] = {
            'read': {}, 'write': {}
        }
        for iotype in self.latency_metrics:
            for thrd_num in tqdm(self.latency_metrics[iotype],
                        desc=f'Collecting percentiles for {iotype=}'):
                if not len(self.latency_metrics[iotype][thrd_num]['latency(us)']):
                    continue
                sorted_latencies = self.latency_metrics[iotype][thrd_num]['latency(us)']\
                        .sort_values(ascending=True, ignore_index=True)
                # Nearest-rank percentile lookup on the sorted series.
                data[iotype][f'{thrd_num} thrd'] = [
                    sorted_latencies[int((len(sorted_latencies) - 1) * p)]
                    for p in percentiles
                ]
        for iotype in data:
            if not data[iotype]:
                # plot.line() on a column-less frame raises; nothing to show.
                logging.warning(f'No data to visualize for {iotype=}')
                continue
            logging.info(f'Showing visualization for {iotype=}')
            ax = pd.DataFrame(data[iotype], index=percentiles).plot.line(alpha=0.7)
            ax.set_xlabel('Percentile')
            xticks = list(range(0, 101, 10))
            ax.set_xticks([p / 100 for p in xticks], [f'{p}%' for p in xticks])
            ax.set_xlim(0, 1)
            ax.set_ylabel('Latency (us)')
            ax.set_ylim(0, 100)
            plt.show()

# Make this analyser a detection candidate for main().
register_analyser(LatencyBenchAnalyser)


class DistributedBenchAnalyser(BenchAnalyserABC):
    '''Analyser for multi-node latency datasets.

    Expected layout: files named ``node<nid>_thrd<tid>_<iotype>.csv`` sitting
    directly under ``basedir``, where ``iotype`` is ``read`` or ``write``.
    '''

    def __init__(self, basedir: os.PathLike):
        '''
        Args:
            basedir: Path to directory containing dataset.
        '''
        super().__init__(basedir)
        column_types = {
            'latency(us)': float,
        }
        # One accumulator frame per iotype, starting empty with float dtype.
        self.read_latencies = pd.DataFrame(
            columns=list(column_types.keys())).astype(column_types)
        self.write_latencies = pd.DataFrame(
            columns=list(column_types.keys())).astype(column_types)

    # Matches per-node-per-thread latency file names.  Groups:
    #   nodeid: Node ID.
    #   thrdid: Thread ID.
    #   iotype: `read` or `write`.
    __lat_file_regex = re.compile(
        r'node(?P<nodeid>\d+)_'
        r'thrd(?P<thrdid>\d+)_'
        r'(?P<iotype>read|write)\.csv'
    )

    @staticmethod
    def is_my_dataset(basedir: os.PathLike) -> bool:
        '''Test if this analyser applies to this directory structure.

        Note: an empty ``basedir`` vacuously matches.

        Args:
            basedir: Path to directory containing dataset.
        '''
        Cls = DistributedBenchAnalyser
        for f in os.listdir(basedir):
            p = os.path.join(basedir, f)
            if (not os.path.isfile(p) or
                    not Cls.__lat_file_regex.match(f)):
                return False
        return True

    def load(self) -> None:
        '''Loads collected data from files into the per-iotype frames.'''
        for f in os.listdir(self._basedir):
            m = self.__lat_file_regex.match(f)
            if m is None:
                # is_my_dataset() guarantees a match; fail loudly instead of
                # silencing the type checker if the directory changed since.
                raise RuntimeError(f'Unexpected file name {f!r}')
            iotype = m.group('iotype')
            loaded = pd.read_csv(os.path.join(self._basedir, f))
            if iotype == 'read':
                self.read_latencies = pd.concat([self.read_latencies, loaded], ignore_index=True)
            elif iotype == 'write':
                self.write_latencies = pd.concat([self.write_latencies, loaded], ignore_index=True)
            else:
                raise RuntimeError(f'Unknown {iotype=}')

    @staticmethod
    def _percentile_values(latencies: pd.Series, percentiles: List[float],
                           desc: str) -> List[float]:
        # Nearest-rank percentile lookup on the sorted series.
        sorted_latencies = latencies.sort_values(ascending=True, ignore_index=True)
        return [
            sorted_latencies[int((len(sorted_latencies) - 1) * p)]
            for p in tqdm(percentiles, desc=desc)
        ]

    def visualize(self) -> None:
        '''Plot one latency/percentile line per iotype.'''
        NUM_BINS = 1000
        percentiles = [i / NUM_BINS for i in range(NUM_BINS)] + [1.,]
        # iotype -> latency at each percentile
        data: Dict[str, List[float]] = {}
        if len(self.read_latencies):
            data['read'] = self._percentile_values(
                self.read_latencies['latency(us)'], percentiles,
                'Calculating percentiles for read latency')
        if len(self.write_latencies):
            data['write'] = self._percentile_values(
                self.write_latencies['latency(us)'], percentiles,
                'Calculating percentiles for write latency')
        if not data:
            # plot.line() on a column-less frame raises; nothing to show.
            logging.warning('No latency data loaded; skipping visualization.')
            return
        ax = pd.DataFrame(data=data, index=percentiles).plot.line(alpha=0.7)
        ax.set_xlabel('Percentile')
        xticks = list(range(0, 101, 10))
        ax.set_xticks([p / 100 for p in xticks], [f'{p}%' for p in xticks])
        ax.set_xlim(0, 1)
        ax.set_ylabel('Latency (us)')
        ax.set_ylim(0, 70)
        plt.show()

# Make this analyser a detection candidate for main().
register_analyser(DistributedBenchAnalyser)


############################################################

def main(basedir: os.PathLike) -> None:
    '''
    Visualize collected per-operation latency metrics.

    Args:
        basedir: Directory storing recorded latencies.

    Raises:
        FileNotFoundError: If ``basedir`` is not an existing directory.
        RuntimeError: If no registered analyser recognizes the layout.
    '''
    # basic checks
    if not os.path.isdir(basedir):
        raise FileNotFoundError(f'{basedir=}')
    # select the first registered analyser compatible with the layout
    analyser: Optional[BenchAnalyserABC] = None
    for anacls in __g_all_analysers:
        if anacls.is_my_dataset(basedir):
            analyser = anacls(basedir)
            break
    if analyser is None:
        raise RuntimeError(f'Could not judge dataset type. {basedir=}')
    logging.info(f'Analyser guessed based on directory is {analyser.__class__.__name__}.')
    # analyse and visualize metrics
    analyser.load()
    analyser.visualize()

if __name__ == '__main__':
    # fire exposes main's signature as the CLI: `script.py BASEDIR`.
    fire.Fire(main)
